{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Explain XGB on TBI signal data\n",
    "\n",
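    "Explain an XGBoost model trained on the raw TBI signal data, using SHAP attributions computed with `TreeExplainer` against a background sample drawn from the training set."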
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `xgb`, `split_data` and `DEBUG` are not defined anywhere in this notebook;\n",
    "# they are expected to be provided by this star import.\n",
    "from tbi_downstream_prediction import *\n",
    "import os\n",
    "import numpy as np\n",
    "import shap\n",
    "\n",
    "PATH = \"/homes/gws/hughchen/phase/tbi_subset/\"\n",
    "DPATH = PATH+\"tbi/processed_data/hypoxemia/\"\n",
    "RESULTPATH = PATH+\"results/\"\n",
    "\n",
    "# Set important variables\n",
    "label_type = \"desat_bool92_5_nodesat\"; eta = 0.02\n",
    "hosp_data = \"tbi\"; data_type = \"raw[top11]\"\n",
    "mod_type = \"xgb_{}_eta{}\".format(label_type,eta)\n",
    "\n",
    "# Set up result directory (exist_ok avoids the check-then-create race)\n",
    "RESDIR = '{}{}/'.format(RESULTPATH, mod_type)\n",
    "os.makedirs(RESDIR, exist_ok=True)\n",
    "\n",
    "# Load tbi data as read-only memory maps\n",
    "y_tbi = np.load(DPATH+\"tbiy.npy\",mmap_mode=\"r\")\n",
    "X_tbi = np.load(DPATH+\"X_tbi_imp_standard.npy\",mmap_mode=\"r\")\n",
    "\n",
    "feats = [\"ECGRATE\", \"ETCO2\", \"ETSEV\", \"ETSEVO\", \"FIO2\", \"NIBPD\", \"NIBPM\",\n",
    "         \"NIBPS\",\"PEAK\", \"PEEP\", \"PIP\", \"RESPRATE\", \"SAO2\", \"TEMP1\", \"TV\"]\n",
    "weird_feats = [\"ETSEV\", \"PIP\", \"PEEP\", \"TV\"]\n",
    "\n",
    "# Keep every signal that is not listed in weird_feats; the name list and the\n",
    "# index array are derived from the same filter so they cannot drift apart.\n",
    "feat_lst2 = [feat for feat in feats if feat not in weird_feats]\n",
    "feat_inds = np.array([feats.index(feat) for feat in feat_lst2])\n",
    "X_tbi2 = X_tbi[:,feat_inds,:]\n",
    "(X_test, y_test, X_valid, y_valid, X_train, y_train) = split_data(DPATH,X_tbi2,y_tbi)\n",
    "\n",
    "# Parameters the model was trained with; only 'eta' is read below, to build the\n",
    "# file name of the saved model.\n",
    "param = {'max_depth':6, 'eta':eta, 'subsample':0.5, 'gamma':1.0,\n",
    "         'min_child_weight':10, 'base_score':y_train.mean(),\n",
    "         'objective':'binary:logistic', 'eval_metric':[\"logloss\"]}\n",
    "\n",
    "# Load the previously trained booster\n",
    "save_path = RESDIR+\"hosp{}_data/{}/\".format(hosp_data,data_type)\n",
    "if DEBUG: print(\"[DEBUG] Loading model from {}\".format(save_path))\n",
    "bst = xgb.Booster()\n",
    "bst.load_model(save_path+'mod_eta{}.model'.format(param['eta']))\n",
    "\n",
    "# Draw the background sample used by the explainer. np.random.choice samples\n",
    "# with replacement by default, so the background set may contain repeated rows.\n",
    "BACKGROUND_SEED = 10\n",
    "N_BACKGROUND = 1000\n",
    "np.random.seed(BACKGROUND_SEED)\n",
    "background_inds = np.random.choice(np.arange(0,X_train.shape[0]),N_BACKGROUND)\n",
    "X_train_background = X_train[background_inds]\n",
    "\n",
    "# Generate attributions for the whole training set and save them next to the model.\n",
    "# NOTE(review): `feature_dependence` is the keyword this notebook was written\n",
    "# against; confirm it is still accepted before upgrading shap.\n",
    "ind_explainer = shap.TreeExplainer(bst,data=X_train_background,feature_dependence=\"independent\")\n",
    "X_train_ind_explanations = ind_explainer.shap_values(X_train)\n",
    "np.save(save_path+\"X_train_explanations_ind\",X_train_ind_explanations)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Plot the attributions for each time point"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools\n",
    "\n",
    "# Directory holding this model's artifacts\n",
    "save_path = RESDIR+\"hosp{}_data/{}/\".format(hosp_data,data_type)\n",
    "\n",
    "# One display name per (signal, time index) pair, signal-major:\n",
    "# every name in feat_lst2 followed by \"_0\" ... \"_59\".\n",
    "feat_lst2_expanded = [\n",
    "    signal+\"_\"+str(step)\n",
    "    for signal, step in itertools.product(feat_lst2, range(0,60))\n",
    "]\n",
    "\n",
    "# Summary plot over the individual time-point columns, showing the top 10\n",
    "shap.summary_plot(X_train_ind_explanations,features=X_train,feature_names=feat_lst2_expanded,max_display=10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Plot the aggregated attributions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collapse the per-time-point columns into one column per signal by summing each\n",
    "# consecutive block of 60 columns; applied to both the attributions and the inputs.\n",
    "def _sum_per_signal(per_step):\n",
    "    blocks = [per_step[:,60*k:60*(k+1)].sum(1) for k in range(len(feat_lst2))]\n",
    "    return np.stack(blocks,axis=1)\n",
    "\n",
    "X_train_explanations_agg = _sum_per_signal(X_train_ind_explanations)\n",
    "X_train_agg = _sum_per_signal(X_train)\n",
    "shap.summary_plot(X_train_explanations_agg,features=X_train_agg,feature_names=feat_lst2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Plot the time series signal and attribution for a single sample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "DPATH = \"/homes/gws/hughchen/phase/tbi_subset/tbi/processed_data/hypoxemia/\"\n",
    "FIG_DIR = \"fig/\"   # output directory for the saved PDFs\n",
    "SIGNAL_LEN = 60    # number of attribution columns per signal\n",
    "\n",
    "feat_lst = [\"ECGRATE\", \"ETCO2\", \"ETSEV\", \"ETSEVO\", \"FIO2\", \"NIBPD\", \"NIBPM\",\n",
    "            \"NIBPS\",\"PEAK\", \"PEEP\", \"PIP\", \"RESPRATE\", \"SAO2\", \"TEMP1\", \"TV\"]\n",
    "\n",
    "# The data used for training is normalized.  Here we load the unnormalized data\n",
    "# and impute it ourselves, for the sake of plotting. It gets its own name so the\n",
    "# standardized X_tbi loaded earlier is not overwritten.\n",
    "X_tbi_raw = np.load(DPATH+\"tbiX.npy\",mmap_mode=\"r\")\n",
    "\n",
    "def impute_data(dataX):\n",
    "    \"\"\"Replace zeros in dataX, in place, with the median of its non-zero entries.\n",
    "\n",
    "    If that median is NaN (for example when there are no non-zero entries),\n",
    "    zero is used as the fill value instead. Returns the array that was passed in.\n",
    "    \"\"\"\n",
    "    imp_val = np.median(dataX[dataX != 0])\n",
    "    if np.isnan(imp_val): imp_val = 0\n",
    "    dataX[dataX == 0] = imp_val\n",
    "    return dataX\n",
    "\n",
    "# NOTE(review): an earlier version of this loop also loaded a per-feature\n",
    "# \"X_train_val\" array and imputed it, but never used the result, so that dead\n",
    "# I/O was removed. If the intent was to impute with a training-set median, that\n",
    "# has to be wired in explicitly.\n",
    "X_tbi_imp = []\n",
    "for feat_ind, feat in enumerate(feat_lst):\n",
    "    # Copy so the read-only memory map is never written to\n",
    "    X_tbi_curr = np.copy(X_tbi_raw[:,feat_ind,:])\n",
    "    if feat == \"TV\": X_tbi_curr = X_tbi_curr/1000.0\n",
    "    # NaNs are set to zero first so impute_data fills them like any other zero\n",
    "    X_tbi_curr[np.isnan(X_tbi_curr)] = 0\n",
    "    X_tbi_imp.append(impute_data(X_tbi_curr))\n",
    "\n",
    "X_tbi_imp = np.stack(X_tbi_imp,axis=1)\n",
    "\n",
    "(_, _, _, _, X_train_raw_imp, _) = split_data(DPATH,X_tbi_imp,y_tbi,flatten=False)\n",
    "\n",
    "plt.rcParams[\"figure.figsize\"] = (5,2)\n",
    "\n",
    "# Hard-coded training sample to visualise\n",
    "samp_ind = 14898\n",
    "\n",
    "# savefig does not create missing directories\n",
    "os.makedirs(FIG_DIR, exist_ok=True)\n",
    "\n",
    "def plot_signal_and_attribution(signame, samp_ind, prefix):\n",
    "    \"\"\"Plot one sample's raw signal and its attributions, saving both as PDFs.\n",
    "\n",
    "    signame  -- signal name; must be present in feat_lst and feat_lst2\n",
    "    samp_ind -- row index into X_train_raw_imp / X_train_ind_explanations\n",
    "    prefix   -- file name prefix; writes <prefix>valu<samp_ind>.pdf and\n",
    "                <prefix>attr<samp_ind>.pdf under FIG_DIR\n",
    "    \"\"\"\n",
    "    # NOTE(review): the signal index is applied to the LAST axis here. That is\n",
    "    # only right if split_data(..., flatten=False) returns (sample, time, signal);\n",
    "    # X_tbi_imp itself is stacked as (sample, signal, time) -- confirm.\n",
    "    plt.plot(X_train_raw_imp[samp_ind,:,feat_lst.index(signame)])\n",
    "    plt.ylabel(signame)\n",
    "    plt.savefig(FIG_DIR+\"{}valu{}.pdf\".format(prefix,samp_ind))\n",
    "    plt.show()\n",
    "\n",
    "    # Attribution columns follow feat_lst2 order, SIGNAL_LEN columns per signal\n",
    "    ind = feat_lst2.index(signame)\n",
    "    plt.plot(X_train_ind_explanations[samp_ind,(SIGNAL_LEN*ind):(SIGNAL_LEN*(ind+1))])\n",
    "    plt.savefig(FIG_DIR+\"{}attr{}.pdf\".format(prefix,samp_ind))\n",
    "    plt.show()\n",
    "\n",
    "plot_signal_and_attribution(\"TEMP1\", samp_ind, \"temp\")\n",
    "plot_signal_and_attribution(\"SAO2\", samp_ind, \"sao2\")\n",
    "plot_signal_and_attribution(\"FIO2\", samp_ind, \"fio2\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
